"""Quantum-Classical Boundary Emergence Simulation

This module implements a computational model for studying whether classical
behavior can emerge from quantum dynamics through recursive feedback stabilization. The
simulation explores whether behavior resembling wavefunction collapse and decoherence can
arise from predictive corrections that steer quantum evolution toward classical trajectories.

The core method integrates the time-dependent Schrödinger equation (split-step Fourier
method) with an additional correction potential that creates a feedback loop between the
quantum expectation value ⟨x⟩ and a classical prediction of the particle position.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os
from datetime import datetime
from scipy.fftpack import fft, ifft, fftfreq
from scipy.integrate import simpson as simps
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation

# === PHYSICAL AND NUMERICAL PARAMETERS ===

# Uniform spatial grid shared by the propagator, the diagnostics and the plots.
GRID_SIZE = 512                      # number of grid points
X = np.linspace(-10, 10, GRID_SIZE)  # positions spanning [-10, 10], endpoints included
DX = X[1] - X[0]                     # grid spacing Δx

# Time stepping.  Every split-step factor is a pure phase (or an FFT), so the
# propagator is unitary for any Δt; Δt controls splitting accuracy, not stability.
DT = 0.005     # time step Δt
N_STEPS = 1000  # number of time steps per run

# Natural units.
MASS = 1.0  # particle mass m
HBAR = 1.0  # reduced Planck constant ℏ

# === CORRECTION STRENGTH PARAMETER SWEEP ===
# Feedback strengths λ swept by the __main__ block.  λ = 0 disables the feedback
# (plain Schrödinger evolution); λ > 0 scales the feedback potential that
# run_simulation builds from the deviation ⟨x⟩ − x_classical.
LAMBDA_VALUES = [0.0, 0.01, 0.02, 0.05, 0.1]

# === DOUBLE-WELL POTENTIAL DEFINITION ===
def double_well_potential(x):
    """
    Evaluate the symmetric quartic double-well potential.

    Physics: V(x) = x⁴/4 - 2x²
    Setting V'(x) = x³ - 4x to zero gives minima at x = ±2 (V = -4) and a local
    maximum at x = 0 (V = 0).  The quartic term confines the particle at large |x|;
    the negative quadratic term raises the central barrier.

    Parameters
    ----------
    x : float or numpy.ndarray
        Position(s) at which to evaluate the potential.

    Returns
    -------
    float or numpy.ndarray
        V(x), with the same shape as x.
    """
    confinement = 0.25 * x**4
    barrier = 2 * x**2
    return confinement - barrier

# Potential sampled once on the simulation grid; the propagator reuses this array every step.
V = double_well_potential(x=X)

# === INITIAL QUANTUM STATE: SUPERPOSITION OF LOCALIZED WAVEPACKETS ===
def initial_wavefunction(x):
    """
    Return the normalized two-packet superposition used as Ψ(x, 0).

    Physics: Ψ(x, 0) = N [exp(-(x + 2)²) + exp(-(x - 2)²)]
    Two equal-amplitude Gaussians sit on the double-well minima x = ±2, giving a
    "cat-like" superposition of a left-localized and a right-localized state.

    Mathematical details:
    - Each packet's amplitude has width σ = 1/√2 in the form exp(-x²/2σ²), so its
      probability density |Ψ|² ∝ exp(-2(x ∓ 2)²) has standard deviation 1/2.
    - N is chosen so that the Simpson-rule integral of |Ψ|² over the supplied
      grid equals 1.

    Parameters
    ----------
    x : numpy.ndarray
        Position grid on which Ψ is sampled.

    Returns
    -------
    numpy.ndarray (complex128)
        Normalized samples of Ψ(x, 0); the values are real but stored as complex
        so the propagator can multiply phases in place.
    """
    left_packet = np.exp(-(x + 2) ** 2)
    right_packet = np.exp(-(x - 2) ** 2)
    amplitude = left_packet + right_packet

    norm = np.sqrt(simps(np.abs(amplitude) ** 2, x=x))
    return (amplitude / norm).astype(np.complex128)


# === PHYSICAL CONSTANTS AND DERIVED QUANTITIES ===
# Import-time summary of the natural units and of the double-well landmarks.
# V(x) = x⁴/4 - 2x² has V'(x) = x³ - 4x, so its minima sit at x = ±√4 = ±2.
print("Physical parameters:")
print(f"Mass: {MASS}")
print(f"ħ: {HBAR}")
print(f"Potential wells at x = ±{np.sqrt(4):.1f}")
# The barrier height is measured from the well bottoms V(±2) up to the barrier top V(0).
print(f"Barrier height: {double_well_potential(0.0) - double_well_potential(2.0):.1f}")
print(f"Well minimum V(±2): {double_well_potential(2.0):.1f}")
print(f"Spatial resolution: {DX:.4f}")
print(f"Temporal resolution: {DT:.4f}")
print("=" * 30)

# === MOMENTUM SPACE REPRESENTATION ===
# Angular wavenumbers for the split-step propagator, in scipy's FFT output order
# (zero, then positive, then negative frequencies): k_n = 2πn / (N·Δx).
K = 2 * np.pi * fftfreq(GRID_SIZE, d=DX)

# Despite its name, K_SQUARED holds the kinetic energy T(k) = (ℏk)² / 2m of each
# Fourier mode, which is what the kinetic phase factor exp(-iTΔt/ℏ) is built from.
K_SQUARED = np.square(HBAR * K) / (2.0 * MASS)

# === SHANNON ENTROPY CALCULATION ===
def compute_entropy(psi):
    """
    Shannon entropy, in bits, of the grid distribution pᵢ = |Ψ(xᵢ)|² / Σⱼ|Ψ(xⱼ)|².

    The weights are renormalized to sum to one, so the result does not depend on
    how psi was scaled.  A 1e-12 offset inside the logarithm keeps empty grid
    points from producing log(0).  Lower values mean the probability is
    concentrated on fewer grid points (more localized); higher values mean it is
    spread over more of them (more delocalized).

    Parameters
    ----------
    psi : numpy.ndarray (complex)
        Wavefunction samples Ψ(xᵢ).  Not modified.

    Returns
    -------
    float
        H = -Σ pᵢ log₂(pᵢ + 1e-12).
    """
    weights = np.abs(psi) ** 2
    weights = weights / weights.sum()
    log_weights = np.log2(weights + 1e-12)
    return -np.sum(weights * log_weights)

# === CLASSICAL TRAJECTORY CALCULATION ===
def classical_path(x0, v0, t, V_func, *, dt=None, x_grid=None, mass=None):
    """
    Integrate Newton's equation m·d²x/dt² = -dV/dx from (x0, v0) up to time t.

    The scheme is semi-implicit (symplectic) Euler:
        vₙ₊₁ = vₙ + a(xₙ)Δt
        xₙ₊₁ = xₙ + vₙ₊₁Δt
    with a(x) = -(1/m) dV/dx.  The derivative is obtained once by finite
    differences (np.gradient) of V_func sampled on the grid, then linearly
    interpolated at the current position.  Outside the grid, np.interp clamps to
    the edge values.

    Parameters
    ----------
    x0 : float
        Initial position.
    v0 : float
        Initial velocity.
    t : float
        Final time.  The number of steps is round(t / dt); t <= 0 returns x0.
    V_func : callable
        Potential energy function V(x), evaluated on the whole grid.
    dt : float, optional
        Time step; defaults to the module constant DT.
    x_grid : numpy.ndarray, optional
        Uniform grid used to sample V_func; defaults to the module grid X.
    mass : float, optional
        Particle mass; defaults to the module constant MASS.

    Returns
    -------
    float
        Classical position after round(t / dt) steps.
    """
    dt = DT if dt is None else dt
    grid = X if x_grid is None else np.asarray(x_grid, dtype=float)
    mass = MASS if mass is None else mass

    # The force field does not depend on the particle's position, so sample it once
    # instead of re-differentiating the whole grid on every step.
    spacing = grid[1] - grid[0]
    dV_dx = np.gradient(V_func(grid), spacing)

    # round(), not int(): t / dt is a float ratio (0.3 / 0.1 == 2.9999999999999996),
    # and truncating it would silently drop the final step.
    n_steps = int(round(t / dt))

    x, v = x0, v0
    for _ in range(n_steps):
        acceleration = -np.interp(x, grid, dV_dx) / mass
        v += acceleration * dt  # kick with the force at the current position
        x += v * dt             # drift with the updated velocity
    return x

# === WIGNER FUNCTION COMPUTATION ===
def compute_wigner(psi, x, p=None, *, hbar=None):
    """
    Calculate the Wigner quasi-probability distribution in phase space.

    Physics:
        W(x, p) = (1/πℏ) ∫ Ψ*(x + y) Ψ(x − y) exp(2ipy/ℏ) dy

    Discretization:
    - The integral over y is a Riemann sum over grid lags y = mΔx.  Only pairs
      (x ± y) that both fall on the grid contribute.
    - The lag set is symmetric in m, so the sum is real up to round-off.
    - All (x, p) points are assembled at once as a single matrix product.

    Aliasing: because y is sampled with step Δx and the phase is exp(2ipy/ℏ), the
    sum is periodic in p with period πℏ/Δx.  Only |p| ≤ πℏ/(2Δx) is resolved.
    Rows with larger |p| would be aliased copies of lower momenta, so they are
    set to zero.

    W can take negative values, which signal quantum interference.

    Parameters
    ----------
    psi : numpy.ndarray (complex)
        Wavefunction samples Ψ(xᵢ).
    x : numpy.ndarray
        Position grid; assumed uniform, with spacing taken as x[1] - x[0].
    p : array_like, optional
        Momenta at which to evaluate W.  Defaults to the module momentum grid K,
        which is in FFT order (not sorted).
    hbar : float, optional
        Reduced Planck constant; defaults to the module constant HBAR.

    Returns
    -------
    numpy.ndarray
        Real array of shape (len(p), len(x)) with wigner[j, i] = W(x[i], p[j]).
    """
    p = np.atleast_1d(np.asarray(K if p is None else p, dtype=float))
    hbar = HBAR if hbar is None else hbar
    psi = np.asarray(psi, dtype=np.complex128)
    n = len(x)
    dx = x[1] - x[0]

    # Symmetric lags m = -n//2 ... n//2, so that y = m·Δx and its negative are both present.
    lags = np.arange(-(n // 2), n // 2 + 1)
    cols = np.arange(n)
    idx_plus = cols[np.newaxis, :] + lags[:, np.newaxis]   # grid index of x + y
    idx_minus = cols[np.newaxis, :] - lags[:, np.newaxis]  # grid index of x − y
    on_grid = (idx_plus >= 0) & (idx_plus < n) & (idx_minus >= 0) & (idx_minus < n)

    # Correlation Ψ*(x + y) Ψ(x − y) for every (lag, position); off-grid pairs contribute 0.
    corr = (np.conj(psi[np.clip(idx_plus, 0, n - 1)])
            * psi[np.clip(idx_minus, 0, n - 1)])
    corr[~on_grid] = 0.0

    # phase[j, m] = exp(2i p_j y_m / ℏ); the product sums over lags for every (p, x).
    phase = np.exp(2j * np.outer(p, lags * dx) / hbar)
    wigner = (dx / (np.pi * hbar)) * (phase @ corr).real

    # Zero the rows beyond the resolvable momentum range (they would be aliases).
    p_max = np.pi * hbar / (2 * dx)
    wigner[np.abs(p) > p_max, :] = 0.0
    return wigner

# === MAIN SIMULATION FUNCTION ===
def _save_and_close(fig, results_dir, filename):
    """Write *fig* to ``results_dir/filename`` and close it.

    run_simulation is called once per λ value, so closing each figure keeps
    pyplot from accumulating every figure of the sweep.
    """
    fig.savefig(os.path.join(results_dir, filename), dpi=300, bbox_inches='tight')
    plt.close(fig)


def _plot_wigner(wigner, results_dir, lambda_correction):
    """Save 2D and 3D views of a Wigner array whose rows follow the ordering of K."""
    # Rows are indexed like K, which is in FFT order (0, +k, ..., -k).  Shift both
    # so that momentum increases monotonically along the plotted axis.
    p_axis = np.fft.fftshift(K)
    w_sorted = np.fft.fftshift(wigner, axes=0)
    # Symmetric colour limits put W = 0 at the neutral centre of the diverging map,
    # so negative regions are visually distinct from positive ones.
    limit = float(np.max(np.abs(w_sorted))) or 1.0

    # Plot 5: Wigner function phase-space distribution
    fig = plt.figure(figsize=(10, 8))
    plt.imshow(w_sorted, extent=[X[0], X[-1], p_axis[0], p_axis[-1]],
               aspect='auto', origin='lower', cmap='seismic',
               vmin=-limit, vmax=limit)
    plt.colorbar(label="Wigner Function W(x,p)")
    plt.title(f"Wigner Phase Space Distribution (\u03bb = {lambda_correction:.3f})")
    plt.xlabel("Position x")
    plt.ylabel("Momentum p")
    _save_and_close(fig, results_dir, "wigner_distribution.png")

    # Plot 6: 3D Wigner function surface
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')
    x_grid, p_grid = np.meshgrid(X, p_axis)
    surf = ax.plot_surface(x_grid, p_grid, w_sorted, cmap=cm.seismic,
                           linewidth=0, antialiased=True, alpha=0.8)
    ax.set_title(f"3D Wigner Distribution (\u03bb = {lambda_correction:.3f})")
    ax.set_xlabel("Position x")
    ax.set_ylabel("Momentum p")
    ax.set_zlabel("W(x, p)")
    plt.colorbar(surf, shrink=0.5, aspect=5)
    _save_and_close(fig, results_dir, "wigner_3D_surface.png")


def _save_animation(waveframes, results_dir, lambda_correction, frame_stride):
    """Save a GIF of the sampled densities; frame k shows t = k·frame_stride·DT."""
    fig, ax = plt.subplots(figsize=(10, 6))
    line, = ax.plot([], [], 'b-', linewidth=2)
    ax.set_xlim(X[0], X[-1])
    ax.set_ylim(0, np.max(waveframes) * 1.1)
    ax.set_title(f"Wavefunction Evolution (\u03bb = {lambda_correction:.3f})")
    ax.set_xlabel("Position x")
    ax.set_ylabel("Probability Density |Ψ(x,t)|²")
    ax.grid(True, alpha=0.3)

    def init():
        """Initialize animation."""
        line.set_data([], [])
        return line,

    def update(frame):
        """Update animation frame."""
        line.set_data(X, waveframes[frame])
        ax.set_title(f"Wavefunction Evolution (\u03bb = {lambda_correction:.3f}, "
                     f"t = {frame * frame_stride * DT:.2f})")
        return line,

    ani = animation.FuncAnimation(fig, update, frames=len(waveframes),
                                  init_func=init, blit=True, interval=50)
    ani.save(os.path.join(results_dir, "wavefunction_evolution.gif"),
             writer='pillow', fps=15, dpi=150)
    plt.close(fig)


def run_simulation(lambda_correction):
    """
    Run the feedback-corrected Schrödinger evolution for one λ and save all outputs.

    Algorithm (split-step Fourier, symmetric splitting):
    1. Start from the two-packet superposition returned by initial_wavefunction.
    2. For each of N_STEPS time steps:
       a) apply the half-step potential phase exp(-iVΔt/2ℏ)
       b) FFT, apply the kinetic phase exp(-iT̂Δt/ℏ), inverse FFT
       c) apply the second half-step potential phase
       d) renormalize (the step is unitary; this only removes round-off drift)
       e) apply the feedback phase exp(-iV_corr·Δt/ℏ) with
          V_corr = λ(⟨x⟩ − x_cl)(x − ⟨x⟩)
       f) every 10 steps, record entropy, ⟨x⟩, x_cl, Δx and |Ψ|²
    3. Compute the Wigner function of the final state.  Write the plots and a GIF
       into a new timestamped directory under the current working directory.

    Parameters
    ----------
    lambda_correction : float
        Feedback strength λ; λ = 0 gives uncorrected evolution.

    NOTE(review): the initial state and V(x) are both even in x, and the classical
    reference starts at rest at x = 0, where dV/dx = 0.  Both ⟨x⟩ and x_cl are
    therefore expected to stay ≈ 0, which makes the feedback ≈ 0 for every λ.
    Confirm this is the intended experiment.
    """
    sample_every = 10  # diagnostics cadence, in time steps

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    results_dir = f"classical_emergence_lambda_{lambda_correction:.3f}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    psi = initial_wavefunction(X)

    # These phase factors do not change between steps, so build them once.
    half_potential_phase = np.exp(-1j * V * DT / (2 * HBAR))
    kinetic_phase = np.exp(-1j * K_SQUARED * DT / HBAR)

    entropies = []            # Shannon entropy H(t)
    positions = []            # wavepacket centre ⟨x⟩(t)
    classical_positions = []  # classical reference x_cl(t) at the same times
    uncertainties = []        # position uncertainty Δx(t)
    waveframes = []           # |Ψ|² snapshots for the animation

    for step in range(N_STEPS):
        # Split-step propagation: U ≈ e^{-iVΔt/2ℏ} e^{-iT̂Δt/ℏ} e^{-iVΔt/2ℏ}
        psi *= half_potential_phase
        psi = ifft(fft(psi) * kinetic_phase)
        psi *= half_potential_phase
        psi /= np.sqrt(simps(np.abs(psi) ** 2, x=X))

        # Recursive correction: compare ⟨x⟩ with the classical prediction.
        center = simps(X * np.abs(psi) ** 2, x=X)
        expected_classical = classical_path(0.0, 0.0, step * DT, double_well_potential)

        # V_corr = λ(⟨x⟩ − x_cl)(x − ⟨x⟩) is linear in x, so it exerts the uniform
        # force F = −∂V_corr/∂x = −λ(⟨x⟩ − x_cl), which pulls ⟨x⟩ back toward x_cl.
        correction_strength = lambda_correction * (center - expected_classical)
        V_corr = correction_strength * (X - center)
        psi *= np.exp(-1j * V_corr * DT / HBAR)

        if step % sample_every == 0:
            density = np.abs(psi) ** 2
            entropies.append(compute_entropy(psi))
            positions.append(center)
            classical_positions.append(expected_classical)
            x2_mean = simps(X ** 2 * density, x=X)
            # Clamp a tiny negative round-off so the square root never yields NaN.
            uncertainties.append(np.sqrt(max(x2_mean - center ** 2, 0.0)))
            waveframes.append(density)

    wigner = compute_wigner(psi, X)
    time_points = np.arange(0, N_STEPS, sample_every) * DT
    final_density = np.abs(psi) ** 2

    # Plot 1: Shannon entropy evolution
    fig = plt.figure(figsize=(10, 6))
    plt.plot(time_points, entropies, 'b-', linewidth=2)
    plt.title(f"Shannon Entropy Evolution (\u03bb = {lambda_correction:.3f})")
    plt.xlabel("Time t")
    plt.ylabel("Shannon Entropy H(t) [bits]")
    plt.grid(True, alpha=0.3)
    _save_and_close(fig, results_dir, "entropy_over_time.png")

    # Plot 2: wavepacket centre vs. the classical reference recorded in the loop
    fig = plt.figure(figsize=(10, 6))
    plt.plot(time_points, positions, 'r-', linewidth=2, label='Quantum ⟨x⟩(t)')
    plt.plot(time_points, classical_positions, 'k--', linewidth=1, label='Classical x(t)')
    plt.title(f"Center Position Evolution (\u03bb = {lambda_correction:.3f})")
    plt.xlabel("Time t")
    plt.ylabel("Position ⟨x⟩(t)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    _save_and_close(fig, results_dir, "position_over_time.png")

    # Plot 3: position uncertainty evolution
    fig = plt.figure(figsize=(10, 6))
    plt.plot(time_points, uncertainties, 'g-', linewidth=2)
    plt.title(f"Position Uncertainty Evolution (\u03bb = {lambda_correction:.3f})")
    plt.xlabel("Time t")
    plt.ylabel("Position Uncertainty Δx(t)")
    plt.grid(True, alpha=0.3)
    _save_and_close(fig, results_dir, "uncertainty_over_time.png")

    # Plot 4: final probability density
    fig = plt.figure(figsize=(10, 6))
    plt.plot(X, final_density, 'purple', linewidth=2)
    plt.fill_between(X, final_density, alpha=0.3, color='purple')
    plt.title(f"Final Wavefunction Probability Density (\u03bb = {lambda_correction:.3f})")
    plt.xlabel("Position x")
    plt.ylabel("Probability Density |Ψ(x,t_final)|²")
    plt.grid(True, alpha=0.3)
    _save_and_close(fig, results_dir, "final_wavefunction.png")

    # Plots 5 and 6: Wigner function (2D map and 3D surface)
    _plot_wigner(wigner, results_dir, lambda_correction)

    # Animation of the sampled densities
    _save_animation(waveframes, results_dir, lambda_correction, sample_every)

# === PARAMETER SWEEP EXECUTION ===
if __name__ == "__main__":
    # Sweep the feedback strength: one full run_simulation() call (time evolution,
    # diagnostics, plots and GIF, each run in its own output directory) per value in
    # LAMBDA_VALUES.  The point of the sweep is to compare entropy, ⟨x⟩ against the
    # classical path, Δx and the Wigner function as λ grows; whether larger λ
    # actually yields more classical behaviour is what the outputs are meant to show.
    rule = "=" * 50
    header_lines = (
        "Quantum-Classical Boundary Emergence Simulation",
        rule,
        f"Grid size: {GRID_SIZE}",
        f"Time steps: {N_STEPS}",
        f"Time step size: {DT}",
        f"Parameter sweep: λ ∈ {LAMBDA_VALUES}",
        rule,
    )
    for text in header_lines:
        print(text)

    n_runs = len(LAMBDA_VALUES)
    for run_number, lambda_val in enumerate(LAMBDA_VALUES, start=1):
        print(f"Running simulation {run_number}/{n_runs}: λ = {lambda_val:.3f}")
        run_simulation(lambda_val)
        print(f"Completed λ = {lambda_val:.3f}")

    print("\nAll simulations completed successfully!")
    print("Check output directories for results.")
